TypeTree support in autodiff #144197

Open: wants to merge 20 commits into base: master

KMJ-007 (Contributor) commented Jul 19, 2025

TypeTrees for Autodiff

What are TypeTrees?

TypeTrees are memory layout descriptors for Enzyme. They tell Enzyme exactly how types are laid out in memory so it can compute derivatives efficiently.

Structure

struct TypeTree(Vec<Type>);

struct Type {
    offset: isize,   // byte offset (-1 = everywhere)
    size: usize,     // size in bytes
    kind: Kind,      // Float, Integer, Pointer, etc.
    child: TypeTree, // nested structure
}

Example: fn compute(x: &f32, data: &[f32]) -> f32

Input 0: x: &f32

TypeTree(vec![Type {
    offset: 0, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: 0, size: 4, kind: Float,
        child: TypeTree::new()
    }])
}])

Input 1: data: &[f32]

TypeTree(vec![Type {
    offset: 0, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1, size: 4, kind: Float,  // -1 = all elements
        child: TypeTree::new()
    }])
}])

Output: f32

TypeTree(vec![Type {
    offset: 0, size: 4, kind: Float,
    child: TypeTree::new()
}])
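The three trees above can be built with a few small helpers. This is only a minimal sketch: the `Kind`, `Type` and `TypeTree` definitions are simplified stand-ins that follow the names in this description, and `f32_at`, `pointer_to` and `compute_trees` are hypothetical helpers, not functions from this PR. It also assumes a 64-bit target (pointer size 8).

```rust
// Simplified stand-ins for the types described above (assumption, not the PR's code).
#[derive(Clone, Debug, PartialEq)]
pub enum Kind {
    Integer,
    Float,
    Pointer,
}

#[derive(Clone, Debug, PartialEq)]
pub struct TypeTree(pub Vec<Type>);

#[derive(Clone, Debug, PartialEq)]
pub struct Type {
    pub offset: isize,
    pub size: usize,
    pub kind: Kind,
    pub child: TypeTree,
}

impl TypeTree {
    /// An empty tree, used as the child of leaf entries.
    pub fn new() -> Self {
        TypeTree(Vec::new())
    }
}

/// Hypothetical helper: a leaf f32 at `offset` (-1 means "every element").
pub fn f32_at(offset: isize) -> Type {
    Type { offset, size: 4, kind: Kind::Float, child: TypeTree::new() }
}

/// Hypothetical helper: a pointer (8 bytes assumed) whose pointee is `child`.
pub fn pointer_to(child: TypeTree) -> Type {
    Type { offset: 0, size: 8, kind: Kind::Pointer, child }
}

/// Trees for `fn compute(x: &f32, data: &[f32]) -> f32`, as (arguments, return).
pub fn compute_trees() -> (Vec<TypeTree>, TypeTree) {
    let x = TypeTree(vec![pointer_to(TypeTree(vec![f32_at(0)]))]);
    let data = TypeTree(vec![pointer_to(TypeTree(vec![f32_at(-1)]))]);
    let ret = TypeTree(vec![f32_at(0)]);
    (vec![x, data], ret)
}
```

The only difference between the trees for `x` and `data` is the offset of the inner float: 0 for a single value, -1 for every element of the slice.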

Why Needed?

  • Enzyme can't deduce complex type layouts from LLVM IR
  • Prevents slow memory pattern analysis
  • Enables correct derivative computation for nested structures
  • Tells Enzyme which bytes are differentiable vs metadata

What Enzyme Does With This Information:

Without TypeTrees (current state):

; Enzyme sees generic LLVM IR:
define float @distance(i8* %p1, i8* %p2) {
  ; Has to guess what these pointers point to
  ; Slow analysis of all memory operations
  ; May miss optimization opportunities
}

With TypeTrees (our goal):

// Enzyme knows:
// - %p1 points to struct with f32 at +0, f32 at +4, i32 at +8
// - Only the f32 fields need derivatives
// - Can generate efficient derivative code directly
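To make "only the f32 fields need derivatives" concrete, here is a small sketch that picks the float entries out of a flat field list. The `(offset, Kind)` representation and the `differentiable_offsets` function are hypothetical simplifications for illustration; they are not how Enzyme or this PR stores the information.

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Kind {
    Integer,
    Float,
}

/// Returns the byte offsets of the fields that carry derivatives,
/// i.e. the float fields. Integer fields are treated as metadata.
pub fn differentiable_offsets(fields: &[(usize, Kind)]) -> Vec<usize> {
    fields
        .iter()
        .filter(|(_, kind)| *kind == Kind::Float)
        .map(|(offset, _)| *offset)
        .collect()
}
```

For the struct described above (f32 at +0, f32 at +4, i32 at +8), this selects offsets 0 and 4 and skips the integer at offset 8.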

TypeTrees - Offset and -1 Explained

Type Structure

Type {
    offset: isize,  // WHERE this type starts
    size: usize,    // HOW BIG this type is
    kind: Kind,     // WHAT KIND of data (Float, Int, Pointer)
    child: TypeTree // WHAT'S INSIDE (for pointers/containers)
}

Offset Values

Regular Offset (0, 4, 8, etc.)

Specific byte position within a structure

struct Point {
    x: f32,  // offset 0, size 4
    y: f32,  // offset 4, size 4
    id: i32, // offset 8, size 4
}

TypeTree for Point, the pointee of &Point (the outer Pointer entry and the empty child fields are omitted for brevity):

TypeTree(vec![
    Type { offset: 0, size: 4, kind: Float },   // x at byte 0
    Type { offset: 4, size: 4, kind: Float },   // y at byte 4
    Type { offset: 8, size: 4, kind: Integer }, // id at byte 8
])
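One caveat worth checking: Rust's default struct layout does not guarantee field order, so the offsets 0, 4 and 8 are only guaranteed when the struct is `#[repr(C)]`. The sketch below assumes `#[repr(C)]` (an addition for illustration, not part of the example above) and uses `std::mem::offset_of!` (stable since Rust 1.77) to read the offsets; `point_layout` is a hypothetical helper.

```rust
use std::mem::{offset_of, size_of};

// `#[repr(C)]` is assumed here so that field order and offsets are guaranteed.
#[repr(C)]
pub struct Point {
    pub x: f32,
    pub y: f32,
    pub id: i32,
}

/// (field name, byte offset, size in bytes) for each field of `Point`.
pub fn point_layout() -> [(&'static str, usize, usize); 3] {
    [
        ("x", offset_of!(Point, x), size_of::<f32>()),
        ("y", offset_of!(Point, y), size_of::<f32>()),
        ("id", offset_of!(Point, id), size_of::<i32>()),
    ]
}
```

A compiler-side implementation would presumably read these offsets from the computed layout of the type rather than assume them.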

Offset -1 (Special: "Everywhere")

Means "this pattern repeats for ALL elements"

Example 1: Array [f32; 100]

TypeTree(vec![Type {
    offset: -1,  // ALL positions
    size: 4,     // each f32 is 4 bytes
    kind: Float, // every element is float
}])

Instead of listing 100 separate Types with offsets 0,4,8,12...396
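As a quick check of the arithmetic, the sketch below expands a single "everywhere" entry into the explicit offsets it stands for. `expand_everywhere` is a hypothetical helper for illustration only; nothing in this description says the PR performs such an expansion.

```rust
/// Lists the byte offsets covered by one offset -1 entry for an array
/// of `len` elements, each `elem_size` bytes wide.
pub fn expand_everywhere(elem_size: usize, len: usize) -> Vec<usize> {
    (0..len).map(|i| i * elem_size).collect()
}
```

For `[f32; 100]`, `expand_everywhere(4, 100)` yields 100 offsets, starting at 0 and ending at 396, which is the list the single -1 entry replaces.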

Example 2: Slice &[i32]

// Pointer to slice data
TypeTree(vec![Type {
    offset: 0, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1, // ALL slice elements
        size: 4,    // each i32 is 4 bytes
        kind: Integer
    }])
}])

Example 3: Mixed Structure

struct Container {
    header: i64,       // offset 0
    data: [f32; 1000], // offset 8, but elements use -1
}

TypeTree(vec![
    Type { offset: 0, size: 8, kind: Integer }, // header
    Type { offset: 8, size: 4000, kind: Pointer,
        child: TypeTree(vec![Type {
            offset: -1, size: 4, kind: Float // ALL array elements
        }])
    }
])

@rustbot rustbot added F-autodiff `#![feature(autodiff)]` S-waiting-on-author Status: This is awaiting some action (such as code changes or more information) from the author. T-compiler Relevant to the compiler team, which will review and decide on the PR/issue. labels Jul 19, 2025
@rustbot rustbot added the A-LLVM Area: Code generation parts specific to LLVM. Both correctness bugs and optimization-related issues. label Jul 19, 2025
KMJ-007 (Contributor, Author) commented Jul 19, 2025

Currently, I have implemented TypeTree support only for memcpy.

KMJ-007 (Contributor, Author) commented Jul 19, 2025

r? @ZuseZ4

@KMJ-007 KMJ-007 marked this pull request as ready for review July 19, 2025 23:50
@rustbot rustbot added S-waiting-on-review Status: Awaiting review from the assignee but also interested parties. and removed S-waiting-on-author Status: This is awaiting some action (such as code changes or more information) from the author. labels Jul 19, 2025
rustbot (Collaborator) commented Jul 19, 2025

Some changes occurred in compiler/rustc_ast/src/expand/autodiff_attrs.rs

cc @ZuseZ4

Some changes occurred in compiler/rustc_codegen_llvm/src/builder/autodiff.rs

cc @ZuseZ4

Some changes occurred in compiler/rustc_codegen_ssa

cc @WaffleLapkin

Some changes occurred in compiler/rustc_monomorphize/src/partitioning/autodiff.rs

cc @ZuseZ4

rustbot (Collaborator) commented Jul 20, 2025

Some changes occurred in compiler/rustc_codegen_gcc

cc @antoyo, @GuillaumeGomez

KMJ-007 (Contributor, Author) commented Jul 21, 2025

CI is failing; I'm fixing it!

rustbot (Collaborator) commented Jul 23, 2025

Some changes occurred in src/tools/enzyme

cc @ZuseZ4

@rustbot rustbot added has-merge-commits PR has merge commits, merge with caution. S-waiting-on-author Status: This is awaiting some action (such as code changes or more information) from the author. labels Jul 23, 2025
@rustbot rustbot removed the has-merge-commits PR has merge commits, merge with caution. label Jul 23, 2025
bors (Collaborator) commented Aug 6, 2025

☔ The latest upstream changes (presumably #143684) made this pull request unmergeable. Please resolve the merge conflicts.

Comment on lines +4 to +39
use run_make_support::{llvm_filecheck, rfs, rustc};

fn main() {
    // First, compile to LLVM IR to check for enzyme_type attributes
    let _ir_output = rustc()
        .input("memcpy.rs")
        .arg("-Zautodiff=Enable")
        .arg("-Zautodiff=NoPostopt")
        .opt_level("3")
        .arg("-Clto=fat")
        .arg("--emit=llvm-ir")
        .arg("-o")
        .arg("main.ll")
        .run();

    // Then compile with TypeTree analysis output for the existing checks
    let output = rustc()
        .input("memcpy.rs")
        .arg("-Zautodiff=Enable,PrintTAFn=test_memcpy")
        .arg("-Zautodiff=NoPostopt")
        .opt_level("3")
        .arg("-Clto=fat")
        .arg("-g")
        .run();

    let stdout = output.stdout_utf8();
    let stderr = output.stderr_utf8();
    let ir_content = rfs::read_to_string("main.ll");

    rfs::write("memcpy.stdout", &stdout);
    rfs::write("memcpy.stderr", &stderr);
    rfs::write("main.ir", &ir_content);

    llvm_filecheck().patterns("memcpy.check").stdin_buf(stdout).run();

    llvm_filecheck().patterns("memcpy-ir.check").stdin_buf(ir_content).run();
Contributor Author

@jieyouxu, can you review this run-make test? I am creating two check files, one to test the IR and one to test the type analysis output from Enzyme. Is this the correct approach, or should I combine them into one file with a single check?

@KMJ-007 KMJ-007 requested a review from ZuseZ4 August 9, 2025 06:59
KMJ-007 added 20 commits August 9, 2025 13:22
Signed-off-by: Karan Janthe <[email protected]>
bors (Collaborator) commented Aug 12, 2025

☔ The latest upstream changes (presumably #145300) made this pull request unmergeable. Please resolve the merge conflicts.

@@ -2194,3 +2201,335 @@ pub struct DestructuredConst<'tcx> {
    pub variant: Option<VariantIdx>,
    pub fields: &'tcx [ty::Const<'tcx>],
}

// Some types are used a lot. Make sure they don't unintentionally get bigger.
#[cfg(target_pointer_width = "64")]
Member

That's probably code from someone else, can we drop it?

}

#[cfg(not(llvm_enzyme))]
#[allow(dead_code)]
Member

ping

@@ -1253,6 +1253,9 @@ pub struct Resolver<'ra, 'tcx> {
    // that were encountered during resolution. These names are used to generate item names
    // for APITs, so we don't want to leak details of resolution into these names.
    impl_trait_names: FxHashMap<NodeId, Symbol>,

    /// Mapping of autodiff function IDs
    autodiff_map: FxHashMap<LocalDefId, LocalDefId>,
Member

Do we actually use the autodiff_map anywhere?

_ => panic!(""),
};

let fields = adt_def.all_fields();
Member

let fields = adt_def
    .all_fields()
    .into_iter()
    .zip(offsets.into_iter())

if inner_ty.is_slice() {
    // We know that the length will be passed as extra arg.
    let child = typetree_from_ty(inner_ty, tcx, 1, safety, &mut visited, span);
    let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child };
Member

I think on 32-bit targets we probably have size: 4 here. Can you check that (e.g. on the playground) and make it selective if needed?
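A minimal sketch of what "selective" could look like, assuming the pointer width is available in bits. Both functions are hypothetical and standalone; inside rustc the value would have to come from the target's data layout (not the host's), and the exact accessor is left open here.

```rust
/// Pointer size in bytes for a target with the given pointer width in bits.
/// Hypothetical helper: 64-bit targets give 8, 32-bit targets give 4.
pub fn pointer_size_bytes(pointer_width_bits: u32) -> usize {
    (pointer_width_bits / 8) as usize
}

/// Pointer size of the machine running this code. This is the host, which is
/// only correct for the generated code when not cross-compiling.
pub fn host_pointer_size() -> usize {
    std::mem::size_of::<usize>()
}
```

The point of the sketch is that the `size: 8` literal in the hunk above becomes a value derived from the target, so the same code produces 4 on 32-bit targets.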

FncTree { args, ret }
}

fn typetree_from_ty<'a>(
Member

Iirc I added the visited Vec here to detect cycles (recursive types). Can you verify that and add a comment describing what it does?
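For readers unfamiliar with the pattern, here is a sketch of how a `visited` list can stop a recursive walk over self-referential types. It uses a toy type graph keyed by name; `TypeGraph` and `walk` are hypothetical, and whether the `visited` Vec in `typetree_from_ty` works this way is exactly what the comment above asks to be verified.

```rust
use std::collections::HashMap;

/// Toy type graph: type name -> names of the types its fields refer to.
pub type TypeGraph = HashMap<&'static str, Vec<&'static str>>;

/// Walks `ty` depth-first and returns the nesting depth reached.
/// `visited` holds the types on the current path, so a type that refers
/// back to itself (directly or indirectly) ends the descent instead of
/// recursing forever.
pub fn walk(ty: &'static str, graph: &TypeGraph, visited: &mut Vec<&'static str>) -> usize {
    if visited.contains(&ty) {
        return 0; // cycle detected: do not descend again
    }
    visited.push(ty);
    let mut deepest = 0;
    if let Some(fields) = graph.get(ty) {
        for &field in fields {
            deepest = deepest.max(walk(field, graph, visited));
        }
    }
    visited.pop();
    1 + deepest
}
```

With a linked-list-like type such as `List { value: f32, next: List }`, the walk visits `List` once, sees it again through `next`, and returns instead of looping.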

assert!(span.is_some());
let span = span.unwrap();

tcx.sess.dcx().emit_warn(AutodiffUnsafeInnerConstRef { span, ty });
Member

This should also have a test

match ty {
    x if x == tcx.types.f32 => (Kind::Float, 4),
    x if x == tcx.types.f64 => (Kind::Double, 8),
    _ => panic!("floatTy scalar that is neither f32 nor f64"),
Member

I think we support f16 and f128 now, can you add those (And some tests)?
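A sketch of how the match could be extended, written as a standalone function over the float width in bits. The variant names `Half` and `F128` are assumptions for illustration (only `Float` and `Double` appear in the hunk above), and `float_kind` is a hypothetical helper, not the PR's code.

```rust
// Stand-in enum; `Half` and `F128` are assumed names.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Kind {
    Half,
    Float,
    Double,
    F128,
}

/// Maps a float width in bits to its kind and size in bytes.
/// Returns `None` for unsupported widths instead of panicking.
pub fn float_kind(bits: u32) -> Option<(Kind, usize)> {
    match bits {
        16 => Some((Kind::Half, 2)),
        32 => Some((Kind::Float, 4)),
        64 => Some((Kind::Double, 8)),
        128 => Some((Kind::F128, 16)),
        _ => None,
    }
}
```

Returning `Option` rather than panicking is a design choice of this sketch; it lets the caller decide whether an unknown width is a bug or a diagnostic.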

// Not an error, because it only causes issues if they are actually read, which we don't check
// yet. We should add such analysis to reliably either issue an error or accept without warning.
// If there only were some research to do that...
pub fn fnc_typetrees<'tcx>(
Member

This fn is comparatively complex, so we should have tests that cover every case handled here, or you can split some of the logic out into follow-up PRs.

I would probably squash the current PR into 1-3 commits to start and add only very minimal handling to fnc_typetrees (and the other functions in this file). Then you can add follow-up commits for the extra handling, e.g. one commit for array handling plus array tests, one for SIMD and the SIMD tests, one for recursive handling, one for the InnerConstRef handling, etc. This way we can merge it incrementally and be sure that every piece of this function actually works and is tested.

Contributor Author

Yes, I was thinking the same. Should I do follow-up PRs that add the new handling and tests, or is there a basic TypeTree requirement that this PR must meet before it can be merged?

What should the goal of the current PR be? Basic TypeTree support, a working memcpy with tests, and, I think, the NoTT flag, right?

ZuseZ4 (Member) commented Aug 13, 2025

You should also add a flag that disables all TypeTree additions (in case they cause bugs, or just for A/B testing to see where the TypeTrees allow us to compile something or have a compile-time impact). We already have autodiff=Enable; just add another option to that enum so people can pass -Zautodiff=Enable,NoTT.
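A standalone sketch of the suggested behavior, assuming the option value is a comma-separated list. The `AutoDiff` enum here is a reduced stand-in with only three variants, and `parse_autodiff` and `typetrees_enabled` are hypothetical helpers; the real option parsing in rustc is not shown in this conversation.

```rust
// Reduced stand-in for the option enum; `NoTT` is the proposed new variant.
#[derive(Clone, Debug, PartialEq)]
pub enum AutoDiff {
    Enable,
    NoPostopt,
    NoTT,
}

/// Parses a value such as "Enable,NoTT" into a list of options.
pub fn parse_autodiff(value: &str) -> Result<Vec<AutoDiff>, String> {
    value
        .split(',')
        .map(|opt| match opt.trim() {
            "Enable" => Ok(AutoDiff::Enable),
            "NoPostopt" => Ok(AutoDiff::NoPostopt),
            "NoTT" => Ok(AutoDiff::NoTT),
            other => Err(format!("unknown autodiff option: {other}")),
        })
        .collect()
}

/// TypeTrees are emitted only when autodiff is enabled and NoTT is absent.
pub fn typetrees_enabled(opts: &[AutoDiff]) -> bool {
    opts.contains(&AutoDiff::Enable) && !opts.contains(&AutoDiff::NoTT)
}
```

Keeping the check in one predicate would make the A/B comparison described above a matter of toggling a single option.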

Labels
A-LLVM Area: Code generation parts specific to LLVM. Both correctness bugs and optimization-related issues. A-run-make Area: port run-make Makefiles to rmake.rs F-autodiff `#![feature(autodiff)]` S-waiting-on-review Status: Awaiting review from the assignee but also interested parties. T-compiler Relevant to the compiler team, which will review and decide on the PR/issue.
6 participants